# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
# isort:skip_file
"""Unit tests for Superset"""

import datetime
import json
import random
import csv
import pandas as pd
import io

import pytest
import prison
from sqlalchemy.sql import func  # noqa: F401
from unittest import mock

from flask_appbuilder.security.sqla.models import Role
from tests.integration_tests.test_app import app
from superset import db, sql_lab
from superset.common.db_query_status import QueryStatus
from superset.models.core import Database  # noqa: F401
from superset.utils.database import get_example_database, get_main_database  # noqa: F401
from superset.utils import core as utils
from superset.models.sql_lab import Query

from tests.integration_tests.base_tests import SupersetTestCase
from tests.integration_tests.constants import (
    ADMIN_USERNAME,
    GAMMA_SQLLAB_NO_DATA_USERNAME,
)
from tests.integration_tests.fixtures.birth_names_dashboard import load_birth_names_data  # noqa: F401
from tests.integration_tests.fixtures.users import create_gamma_sqllab_no_data  # noqa: F401

# NOTE(review): not referenced by any test visible in this part of the file;
# the name suggests a count of Query fixtures -- confirm whether it is still
# needed before relying on or removing it.
QUERIES_FIXTURE_COUNT = 10


class TestSqlLabApi(SupersetTestCase):
    @pytest.mark.usefixtures("create_gamma_sqllab_no_data")
    @mock.patch.dict(
        "superset.extensions.feature_flag_manager._feature_flags",
        {"SQLLAB_BACKEND_PERSISTENCE": False},
        clear=True,
    )
    def test_get_from_empty_bootsrap_data(self):
        """Check the bootstrap payload for the "gamma sqllab no data" user.

        With ``SQLLAB_BACKEND_PERSISTENCE`` disabled, ``GET /api/v1/sqllab/``
        must answer 200 with no active tab, no tab state ids and no databases.
        The test is skipped on the postgresql backend.
        """
        if utils.backend() == "postgresql":
            # Report the test as skipped instead of returning early: a bare
            # ``return`` is recorded as a pass although nothing was verified.
            pytest.skip("failing on the postgresql backend")

        self.login(GAMMA_SQLLAB_NO_DATA_USERNAME)
        resp = self.client.get("/api/v1/sqllab/")
        assert resp.status_code == 200
        data = json.loads(resp.data.decode("utf-8"))
        result = data.get("result")
        assert result["active_tab"] is None
        assert result["tab_state_ids"] == []
        assert len(result["databases"]) == 0

    @mock.patch.dict(
        "superset.extensions.feature_flag_manager._feature_flags",
        {"SQLLAB_BACKEND_PERSISTENCE": False},
        clear=True,
    )
    def test_get_from_bootstrap_data_for_non_persisted_tab_state(self):
        """Check that a submitted tab is absent from the bootstrap payload.

        With ``SQLLAB_BACKEND_PERSISTENCE`` disabled, a query editor is sent
        to ``/tabstateview/`` and ``GET /api/v1/sqllab/`` must still report no
        active tab and an empty list of tab state ids.
        """
        self.login(ADMIN_USERNAME)

        # Submit a tab description through the tab state view.
        query_editor = {
            "title": "Untitled Query 1",
            "dbId": 1,
            "schema": None,
            "autorun": False,
            "sql": "SELECT ...",
            "queryLimit": 1000,
        }
        self.get_json_resp(
            "/tabstateview/", data={"queryEditor": json.dumps(query_editor)}
        )

        response = self.client.get("/api/v1/sqllab/")
        assert response.status_code == 200

        raw_body = response.data.decode("utf-8")
        bootstrap = json.loads(raw_body).get("result")
        assert bootstrap["active_tab"] is None
        assert bootstrap["tab_state_ids"] == []

    @pytest.mark.usefixtures("load_birth_names_data")
    @mock.patch.dict(
        "superset.extensions.feature_flag_manager._feature_flags",
        {"SQLLAB_BACKEND_PERSISTENCE": True},
        clear=True,
    )
    def test_get_from_bootstrap_data_with_latest_query(self):
        """Check that a newly created tab is reported as the active tab.

        With ``SQLLAB_BACKEND_PERSISTENCE`` enabled, a tab is created through
        ``/tabstateview/`` and the bootstrap payload's ``active_tab`` must
        carry the id returned by that creation call.
        """
        self.login(ADMIN_USERNAME)

        # Create a tab and remember the id the server assigned to it.
        query_editor = {
            "title": "Untitled Query 1",
            "dbId": 1,
            "schema": None,
            "autorun": False,
            "sql": "SELECT ...",
            "queryLimit": 1000,
        }
        created = self.get_json_resp(
            "/tabstateview/", data={"queryEditor": json.dumps(query_editor)}
        )
        tab_state_id = created["id"]

        # The bootstrap payload must point at the tab created above.
        bootstrap = self.get_json_resp("/api/v1/sqllab/")["result"]
        self.assertEqual(bootstrap["active_tab"]["id"], tab_state_id)

    @mock.patch.dict(
        "superset.extensions.feature_flag_manager._feature_flags",
        {"SQLLAB_BACKEND_PERSISTENCE": True},
        clear=True,
    )
    def test_deleted_tab(self):
        """Check that a deleted tab state can be neither fetched nor updated.

        A tab is created and deleted through ``/tabstateview/``; a following
        GET and PUT on the same id must both answer 404.
        """
        # Use the shared constant like the other tests in this class rather
        # than a hard-coded user name.
        self.login(ADMIN_USERNAME)
        data = {
            "queryEditor": json.dumps(
                {
                    "title": "Untitled Query 2",
                    "dbId": 1,
                    "schema": None,
                    "autorun": False,
                    "sql": "SELECT ...",
                    "queryLimit": 1000,
                }
            )
        }
        resp = self.get_json_resp("/tabstateview/", data=data)
        tab_url = f"/tabstateview/{resp['id']}"

        resp = self.client.delete(tab_url)
        assert resp.status_code == 200

        # Once deleted, the tab must be gone for reads and for updates.
        resp = self.client.get(tab_url)
        assert resp.status_code == 404
        resp = self.client.put(tab_url, json=data)
        assert resp.status_code == 404

    @mock.patch.dict(
        "superset.extensions.feature_flag_manager._feature_flags",
        {"SQLLAB_BACKEND_PERSISTENCE": True},
        clear=True,
    )
    def test_delete_tab_already_removed(self):
        """Check that deleting the same tab state twice answers 404.

        The first DELETE on the created tab must answer 200, the second 404.
        """
        # Use the shared constant like the other tests in this class rather
        # than a hard-coded user name.
        self.login(ADMIN_USERNAME)
        data = {
            "queryEditor": json.dumps(
                {
                    "title": "Untitled Query 3",
                    "dbId": 1,
                    "schema": None,
                    "autorun": False,
                    "sql": "SELECT ...",
                    "queryLimit": 1000,
                }
            )
        }
        resp = self.get_json_resp("/tabstateview/", data=data)
        tab_url = f"/tabstateview/{resp['id']}"

        resp = self.client.delete(tab_url)
        assert resp.status_code == 200
        # The tab no longer exists, so a repeated delete must not succeed.
        resp = self.client.delete(tab_url)
        assert resp.status_code == 404

    def test_get_access_denied(self):
        """Check that a user whose role has no permissions gets a 403.

        A role without permissions and a user holding it are created for the
        duration of the test and removed afterwards, even when the request or
        the assertion fails.
        """
        new_role = Role(name="Dummy Role", permissions=[])
        db.session.add(new_role)
        db.session.commit()
        unauth_user = self.create_user(
            "unauth_user1",
            "password",
            "Dummy Role",
            email="unauth_user1@superset.org",
        )
        try:
            self.login(username="unauth_user1", password="password")
            rv = self.client.get("/api/v1/sqllab/")

            assert rv.status_code == 403
        finally:
            # Always remove the temporary user and role so a failing
            # assertion does not leave them behind in the database.
            db.session.delete(unauth_user)
            db.session.delete(new_role)
            db.session.commit()

    def test_estimate_required_params(self):
        """Check that the estimate endpoint rejects payloads missing fields.

        Each payload omits ``sql``, ``database_id`` or both; the endpoint must
        answer 400 and name exactly the missing fields.
        """
        self.login(ADMIN_USERNAME)

        missing = ["Missing data for required field."]
        # (request payload, expected content of the "message" key)
        cases = (
            ({}, {"sql": missing, "database_id": missing}),
            ({"sql": "SELECT 1"}, {"database_id": missing}),
            ({"database_id": 1}, {"sql": missing}),
        )
        for payload, field_errors in cases:
            rv = self.client.post("/api/v1/sqllab/estimate/", json=payload)
            body = json.loads(rv.data.decode("utf-8"))
            self.assertDictEqual(body, {"message": field_errors})
            self.assertEqual(rv.status_code, 400)

    def test_estimate_valid_request(self):
        """Check that a valid estimate request returns the formatted cost.

        The database lookup done by the estimate command is replaced with a
        mock whose engine spec returns a fixed cost and a fixed formatted
        result; the endpoint must answer 200 with that formatted result.
        """
        self.login(ADMIN_USERNAME)

        formatted_cost = [{"value": 100}]

        # Stand-in database whose engine spec yields canned values.
        engine_spec = mock.Mock()
        engine_spec.estimate_query_cost = mock.Mock(return_value=100)
        engine_spec.query_cost_formatter = mock.Mock(return_value=formatted_cost)
        fake_database = mock.Mock()
        fake_database.db_engine_spec = engine_spec

        with mock.patch("superset.commands.sql_lab.estimate.db") as mock_superset_db:
            mock_superset_db.session.query().get.return_value = fake_database
            rv = self.client.post(
                "/api/v1/sqllab/estimate/",
                json={"database_id": 1, "sql": "SELECT 1"},
            )

        body = json.loads(rv.data.decode("utf-8"))
        self.assertDictEqual(body, {"result": formatted_cost})
        self.assertEqual(rv.status_code, 200)

    def test_format_sql_request(self):
        """Check that the format_sql endpoint returns the reformatted SQL."""
        self.login(ADMIN_USERNAME)

        rv = self.client.post(
            "/api/v1/sqllab/format_sql/",
            json={"sql": "select 1 from my_table"},
        )

        body = json.loads(rv.data.decode("utf-8"))
        expected = {"result": "SELECT\n  1\nFROM my_table"}
        self.assertDictEqual(body, expected)
        self.assertEqual(rv.status_code, 200)

    @mock.patch("superset.commands.sql_lab.results.results_backend_use_msgpack", False)
    def test_execute_required_params(self):
        """Check that the execute endpoint rejects payloads missing fields.

        Every payload carries a ``client_id`` but omits ``sql``,
        ``database_id`` or both; the endpoint must answer 400 and name exactly
        the missing fields.
        """
        self.login(ADMIN_USERNAME)
        client_id = str(random.getrandbits(64))[:10]

        missing = ["Missing data for required field."]
        # (fields sent besides client_id, expected content of "message")
        cases = (
            ({}, {"sql": missing, "database_id": missing}),
            ({"sql": "SELECT 1"}, {"database_id": missing}),
            ({"database_id": 1}, {"sql": missing}),
        )
        for fields, field_errors in cases:
            payload = {**fields, "client_id": client_id}
            rv = self.client.post("/api/v1/sqllab/execute/", json=payload)
            body = json.loads(rv.data.decode("utf-8"))
            self.assertDictEqual(body, {"message": field_errors})
            self.assertEqual(rv.status_code, 400)

    @mock.patch("superset.commands.sql_lab.results.results_backend_use_msgpack", False)
    def test_execute_valid_request(self) -> None:
        """Check that executing ``SELECT 1`` succeeds with a 200.

        ``sql_lab.results_backend`` is replaced by a mock only while the
        request runs, so the replacement does not leak into other tests.
        """
        backend_mock = mock.Mock()
        backend_mock.get.return_value = {}

        self.login(ADMIN_USERNAME)
        client_id = f"{random.getrandbits(64)}"[:10]

        data = {"sql": "SELECT 1", "database_id": 1, "client_id": client_id}
        # Scoped patch: the original module attribute is restored on exit,
        # unlike a plain assignment which would persist for the whole run.
        with mock.patch.object(sql_lab, "results_backend", backend_mock, create=True):
            rv = self.client.post(
                "/api/v1/sqllab/execute/",
                json=data,
            )
        resp_data = json.loads(rv.data.decode("utf-8"))
        self.assertEqual(resp_data.get("status"), "success")
        self.assertEqual(rv.status_code, 200)

    @mock.patch(
        "tests.integration_tests.superset_test_custom_template_processors.datetime"
    )
    @mock.patch("superset.sqllab.api.get_sql_results")
    def test_execute_custom_templated(self, sql_lab_mock, mock_dt) -> None:
        """Check that ``$DATE()`` is expanded before the SQL is executed.

        ``utcnow`` of the patched ``datetime`` is pinned to 1970-01-01 and
        ``get_sql_results`` is mocked; the second positional argument it
        receives must be the SQL with the macro already replaced.

        :param sql_lab_mock: mock replacing ``superset.sqllab.api.get_sql_results``
        :param mock_dt: mock replacing ``datetime`` in the custom template
            processors test module
        """
        mock_dt.utcnow = mock.Mock(return_value=datetime.datetime(1970, 1, 1))
        self.login(ADMIN_USERNAME)
        sql = "SELECT '$DATE()' as test"
        resp = {
            "status": QueryStatus.SUCCESS,
            "query": {"rows": 1},
            "data": [{"test": "'1970-01-01'"}],
        }
        sql_lab_mock.return_value = resp

        dbobj = self.create_fake_db_for_macros()
        try:
            json_payload = {"database_id": dbobj.id, "sql": sql}
            self.get_json_resp(
                "/api/v1/sqllab/execute/", raise_on_error=False, json_=json_payload
            )
            assert sql_lab_mock.called
            self.assertEqual(
                sql_lab_mock.call_args[0][1], "SELECT '1970-01-01' as test"
            )
        finally:
            # Remove the fake database even when an assertion above fails.
            self.delete_fake_db_for_macros()

    @mock.patch("superset.commands.sql_lab.results.results_backend_use_msgpack", False)
    def test_get_results_with_display_limit(self):
        """Check that the results endpoint honours the ``rows`` display limit.

        A 100-row payload is serialized without msgpack, compressed and served
        from a mocked results backend.  Fetching it without ``rows`` must
        return all rows; fetching it with ``rows=1`` must return one row and
        set ``displayLimitReached``.

        The results backend and the ``RESULTS_BACKEND_USE_MSGPACK`` setting
        are patched only for the duration of the test and restored afterwards,
        even when an assertion fails.
        """
        from superset.commands.sql_lab import results as command

        self.login(ADMIN_USERNAME)

        data = [{"col_0": i} for i in range(100)]
        payload = {
            "status": QueryStatus.SUCCESS,
            "query": {"rows": 100},
            "data": data,
        }
        expected_key = {"status": "success", "query": {"rows": 100}, "data": data}
        # Expected response when the display limit is set to a single row.
        expected_limited = {
            "status": "success",
            "query": {"rows": 100},
            "data": data[:1],
            "displayLimitReached": True,
        }

        query_mock = mock.Mock()
        query_mock.sql = "SELECT *"
        query_mock.database = 1
        query_mock.schema = "superset"

        backend_mock = mock.Mock()

        # Scoped patches: both the module attribute and the config entry are
        # restored on exit instead of being left modified for later tests.
        with mock.patch.object(
            command, "results_backend", backend_mock, create=True
        ), mock.patch.dict(app.config, {"RESULTS_BACKEND_USE_MSGPACK": False}):
            # do not apply msgpack serialization
            serialized_payload = sql_lab._serialize_payload(payload, False)
            compressed = utils.zlib_compress(serialized_payload)
            backend_mock.get.return_value = compressed

            with mock.patch("superset.commands.sql_lab.results.db") as mock_superset_db:
                mock_superset_db.session.query().filter_by().one_or_none.return_value = (
                    query_mock
                )
                # get all results
                arguments = {"key": "key"}
                result_key = json.loads(
                    self.get_resp(
                        f"/api/v1/sqllab/results/?q={prison.dumps(arguments)}"
                    )
                )
                # limit results to 1
                arguments = {"key": "key", "rows": 1}
                result_limited = json.loads(
                    self.get_resp(
                        f"/api/v1/sqllab/results/?q={prison.dumps(arguments)}"
                    )
                )

        self.assertEqual(result_key, expected_key)
        self.assertEqual(result_limited, expected_limited)

    @mock.patch("superset.models.sql_lab.Query.raise_for_access", lambda _: None)
    @mock.patch("superset.models.core.Database.get_df")
    def test_export_results(self, get_df_mock: mock.Mock) -> None:
        """Check the CSV produced by the export endpoint for a stored query.

        A ``Query`` row with client id ``test`` is persisted, ``Database.get_df``
        is mocked to return a three-row frame, and the CSV returned by
        ``/api/v1/sqllab/export/test/`` is compared with the expected rows.
        The row is deleted afterwards, even when the check fails.

        :param get_df_mock: mock replacing ``Database.get_df``
        """
        self.login(ADMIN_USERNAME)

        database = get_example_database()
        query_obj = Query(
            client_id="test",
            database=database,
            tab_name="test_tab",
            sql_editor_id="test_editor_id",
            sql="select * from bar",
            select_sql=None,
            executed_sql="select * from bar limit 2",
            limit=100,
            select_as_cta=False,
            rows=104,
            error_message="none",
            results_key="test_abc",
        )

        db.session.add(query_obj)
        db.session.commit()

        try:
            get_df_mock.return_value = pd.DataFrame({"foo": [1, 2, 3]})

            resp = self.get_resp("/api/v1/sqllab/export/test/")
            data = csv.reader(io.StringIO(resp))
            expected_data = csv.reader(io.StringIO("foo\n1\n2"))

            self.assertEqual(list(expected_data), list(data))
        finally:
            # Always remove the persisted query so a failure does not leave a
            # row with this client id behind in the database.
            db.session.delete(query_obj)
            db.session.commit()
